{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "f_KNv25DX7B6"
   },
   "source": [
    "#[Super SloMo](https://people.cs.umass.edu/~hzjiang/projects/superslomo/)\n",
    "##High Quality Estimation of Multiple Intermediate Frames for Video Interpolation\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "0VWuBGh6zMMZ"
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import torch\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "import torch.optim as optim\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import model\n",
    "import dataloader\n",
    "import matplotlib.pyplot as plt\n",
    "from math import log10\n",
    "from IPython.display import clear_output, display\n",
    "import datetime\n",
    "from tensorboardX import SummaryWriter"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "1VynXmoKp_3M"
   },
   "source": [
    "##Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "N2yrOVZjqDe9"
   },
   "outputs": [],
   "source": [
    "# Learning Rate. Set `MILESTONES` to epoch values where you want to decrease\n",
    "# learning rate by a factor of 0.1\n",
    "INITIAL_LEARNING_RATE = 0.0001\n",
    "MILESTONES = [100, 150]\n",
    "\n",
    "# Number of epochs to train\n",
    "EPOCHS = 200\n",
    "\n",
    "# Choose batchsize as per GPU/CPU configuration\n",
    "# This configuration works on GTX 1080 Ti\n",
    "TRAIN_BATCH_SIZE = 6\n",
    "VALIDATION_BATCH_SIZE = 10\n",
    "\n",
    "# Path to dataset folder containing train-test-validation folders\n",
    "DATASET_ROOT = \"path/to/dataset\"\n",
    "\n",
    "# Path to folder for saving checkpoints\n",
    "CHECKPOINT_DIR = 'path/to/checkpoint_directory'\n",
    "\n",
    "# If resuming from checkpoint, set `trainingContinue` to True and set `checkpoint_path`\n",
    "TRAINING_CONTINUE = False\n",
    "CHECKPOINT_PATH = 'path/to/checkpoint/file'\n",
    "\n",
    "# Progress and validation frequency (N: after every N iterations)\n",
    "PROGRESS_ITER = 100\n",
    "\n",
    "# Checkpoint frequency (N: after every N epochs). Each checkpoint is roughly of size 151 MB.\n",
    "CHECKPOINT_EPOCH = 5"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Yr3Lm1ovbWv1"
   },
   "source": [
    "##[TensorboardX](https://github.com/lanpa/tensorboardX)\n",
    "### For visualizing loss and interpolated frames"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "saUJTMiMCAzH"
   },
   "outputs": [],
   "source": [
    "writer = SummaryWriter('log')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Ua1DJm82aj5-"
   },
   "source": [
    "###Initialize flow computation and arbitrary-time flow interpolation CNNs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "D42vzEKrWtpG"
   },
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "flowComp = model.UNet(6, 4)\n",
    "flowComp.to(device)\n",
    "ArbTimeFlowIntrp = model.UNet(20, 5)\n",
    "ArbTimeFlowIntrp.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "UYMpk2EYchaY"
   },
   "source": [
    "###Initialze backward warpers for train and validation datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "vJq6SrWIf2GE"
   },
   "outputs": [],
   "source": [
    "trainFlowBackWarp      = model.backWarp(352, 352, device)\n",
    "trainFlowBackWarp      = trainFlowBackWarp.to(device)\n",
    "validationFlowBackWarp = model.backWarp(640, 352, device)\n",
    "validationFlowBackWarp = validationFlowBackWarp.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "oSs9UaIjdTT2"
   },
   "source": [
    "###Load Datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "MJ9cVigEgtyT"
   },
   "outputs": [],
   "source": [
    "# Channel wise mean calculated on adobe240-fps training dataset\n",
    "mean = [0.429, 0.431, 0.397]\n",
    "std  = [1, 1, 1]\n",
    "normalize = transforms.Normalize(mean=mean,\n",
    "                                 std=std)\n",
    "transform = transforms.Compose([transforms.ToTensor(), normalize])\n",
    "\n",
    "trainset = dataloader.SuperSloMo(root=DATASET_ROOT + '/train', transform=transform, train=True)\n",
    "trainloader = torch.utils.data.DataLoader(trainset, batch_size=TRAIN_BATCH_SIZE, shuffle=True)\n",
    "\n",
    "validationset = dataloader.SuperSloMo(root=DATASET_ROOT + '/validation', transform=transform, randomCropSize=(640, 352), train=False)\n",
    "validationloader = torch.utils.data.DataLoader(validationset, batch_size=VALIDATION_BATCH_SIZE, shuffle=False)\n",
    "\n",
    "print(trainset, validationset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "WXmNMdbJfp2d"
   },
   "source": [
    "###Create transform to display image from tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "try3adPHgwse"
   },
   "outputs": [],
   "source": [
    "negmean = [x * -1 for x in mean]\n",
    "revNormalize = transforms.Normalize(mean=negmean, std=std)\n",
    "TP = transforms.Compose([revNormalize, transforms.ToPILImage()])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "32XZg9Mfd5bN"
   },
   "source": [
    "###Test the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "0Vyf7dbwCO1E"
   },
   "outputs": [],
   "source": [
    "for trainIndex, (trainData, frameIndex) in enumerate(trainloader, 0):\n",
    "    frame0, frameT, frame1 = trainData\n",
    "    print(\"Intermediate frame index: \", (frameIndex[0]))\n",
    "    plt.imshow(TP(frame0[0]))\n",
    "    plt.grid(True)\n",
    "    plt.figure()\n",
    "    plt.imshow(TP(frameT[0]))\n",
    "    plt.grid(True)\n",
    "    plt.figure()\n",
    "    plt.imshow(TP(frame1[0]))\n",
    "    plt.grid(True)\n",
    "    break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "rh0MK2qKuBlV"
   },
   "source": [
    "###Utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "BdMFU0ijfIuI"
   },
   "outputs": [],
   "source": [
    "plt.rcParams['figure.figsize'] = [15, 3]\n",
    "def Plot(num, listInp, d):\n",
    "    a = listInp\n",
    "    c = []\n",
    "    for b in a:\n",
    "        c.append(sum(b)/len(b))\n",
    "    plt.subplot(1, 2, num)\n",
    "    plt.plot(c, color=d)\n",
    "    plt.grid(True)\n",
    "    \n",
    "def get_lr(optimizer):\n",
    "    for param_group in optimizer.param_groups:\n",
    "        return param_group['lr']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "mooLcmxtpPR_"
   },
   "source": [
    "###Loss and Optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "BuWQfcb-jhWx"
   },
   "outputs": [],
   "source": [
    "L1_lossFn = nn.L1Loss()\n",
    "MSE_LossFn = nn.MSELoss()\n",
    "\n",
    "params = list(ArbTimeFlowIntrp.parameters()) + list(flowComp.parameters())\n",
    "\n",
    "optimizer = optim.Adam(params, lr=INITIAL_LEARNING_RATE)\n",
    "# scheduler to decrease learning rate by a factor of 10 at milestones.\n",
    "scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=MILESTONES, gamma=0.1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "a5rIkwwfpk1n"
   },
   "source": [
    "###Initializing VGG16 model for perceptual loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "9WR_NxHP51oB"
   },
   "outputs": [],
   "source": [
    "vgg16 = torchvision.models.vgg16(pretrained=True)\n",
    "vgg16_conv_4_3 = nn.Sequential(*list(vgg16.children())[0][:22])\n",
    "vgg16_conv_4_3.to(device)\n",
    "for param in vgg16_conv_4_3.parameters():\n",
    "\t\tparam.requires_grad = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "9-6wLaBJZqsm"
   },
   "source": [
    "### Validation function\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "RhMMZ_I4iDFf"
   },
   "outputs": [],
   "source": [
    "def validate():\n",
    "    # For details see training.\n",
    "    psnr = 0\n",
    "    tloss = 0\n",
    "    flag = 1\n",
    "    with torch.no_grad():\n",
    "        for validationIndex, (validationData, validationFrameIndex) in enumerate(validationloader, 0):\n",
    "            frame0, frameT, frame1 = validationData\n",
    "\n",
    "            I0 = frame0.to(device)\n",
    "            I1 = frame1.to(device)\n",
    "            IFrame = frameT.to(device)\n",
    "                        \n",
    "            \n",
    "            flowOut = flowComp(torch.cat((I0, I1), dim=1))\n",
    "            F_0_1 = flowOut[:,:2,:,:]\n",
    "            F_1_0 = flowOut[:,2:,:,:]\n",
    "\n",
    "            fCoeff = model.getFlowCoeff(validationFrameIndex, device)\n",
    "\n",
    "            F_t_0 = fCoeff[0] * F_0_1 + fCoeff[1] * F_1_0\n",
    "            F_t_1 = fCoeff[2] * F_0_1 + fCoeff[3] * F_1_0\n",
    "\n",
    "            g_I0_F_t_0 = validationFlowBackWarp(I0, F_t_0)\n",
    "            g_I1_F_t_1 = validationFlowBackWarp(I1, F_t_1)\n",
    "            \n",
    "            intrpOut = ArbTimeFlowIntrp(torch.cat((I0, I1, F_0_1, F_1_0, F_t_1, F_t_0, g_I1_F_t_1, g_I0_F_t_0), dim=1))\n",
    "                \n",
    "            F_t_0_f = intrpOut[:, :2, :, :] + F_t_0\n",
    "            F_t_1_f = intrpOut[:, 2:4, :, :] + F_t_1\n",
    "            V_t_0   = F.sigmoid(intrpOut[:, 4:5, :, :])\n",
    "            V_t_1   = 1 - V_t_0\n",
    "                \n",
    "            g_I0_F_t_0_f = validationFlowBackWarp(I0, F_t_0_f)\n",
    "            g_I1_F_t_1_f = validationFlowBackWarp(I1, F_t_1_f)\n",
    "            \n",
    "            wCoeff = model.getWarpCoeff(validationFrameIndex, device)\n",
    "            \n",
    "            Ft_p = (wCoeff[0] * V_t_0 * g_I0_F_t_0_f + wCoeff[1] * V_t_1 * g_I1_F_t_1_f) / (wCoeff[0] * V_t_0 + wCoeff[1] * V_t_1)\n",
    "            \n",
    "            # For tensorboard\n",
    "            if (flag):\n",
    "                retImg = torchvision.utils.make_grid([revNormalize(frame0[0]), revNormalize(frameT[0]), revNormalize(Ft_p.cpu()[0]), revNormalize(frame1[0])], padding=10)\n",
    "                flag = 0\n",
    "            \n",
    "            \n",
    "            #loss\n",
    "            recnLoss = L1_lossFn(Ft_p, IFrame)\n",
    "            \n",
    "            prcpLoss = MSE_LossFn(vgg16_conv_4_3(Ft_p), vgg16_conv_4_3(IFrame))\n",
    "            \n",
    "            warpLoss = L1_lossFn(g_I0_F_t_0, IFrame) + L1_lossFn(g_I1_F_t_1, IFrame) + L1_lossFn(validationFlowBackWarp(I0, F_1_0), I1) + L1_lossFn(validationFlowBackWarp(I1, F_0_1), I0)\n",
    "        \n",
    "            loss_smooth_1_0 = torch.mean(torch.abs(F_1_0[:, :, :, :-1] - F_1_0[:, :, :, 1:])) + torch.mean(torch.abs(F_1_0[:, :, :-1, :] - F_1_0[:, :, 1:, :]))\n",
    "            loss_smooth_0_1 = torch.mean(torch.abs(F_0_1[:, :, :, :-1] - F_0_1[:, :, :, 1:])) + torch.mean(torch.abs(F_0_1[:, :, :-1, :] - F_0_1[:, :, 1:, :]))\n",
    "            loss_smooth = loss_smooth_1_0 + loss_smooth_0_1\n",
    "            \n",
    "            \n",
    "            loss = 204 * recnLoss + 102 * warpLoss + 0.005 * prcpLoss + loss_smooth\n",
    "            tloss += loss.item()\n",
    "            \n",
    "            #psnr\n",
    "            MSE_val = MSE_LossFn(Ft_p, IFrame)\n",
    "            psnr += (10 * log10(1 / MSE_val.item()))\n",
    "            \n",
    "    return (psnr / len(validationloader)), (tloss / len(validationloader)), retImg"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Eh1LB1ufZziF"
   },
   "source": [
    "### Test validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "axBjslWlot7I"
   },
   "outputs": [],
   "source": [
    "a, b, c = validate()\n",
    "print(a, b, c.size())\n",
    "plt.imshow(c.permute(1, 2, 0).numpy())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "1PIFbXuKpBBe"
   },
   "source": [
    "### Initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "gWt-nlx2MSOk"
   },
   "outputs": [],
   "source": [
    "if TRAINING_CONTINUE:\n",
    "    dict1 = torch.load(CHECKPOINT_PATH)\n",
    "    ArbTimeFlowIntrp.load_state_dict(dict1['state_dictAT'])\n",
    "    flowComp.load_state_dict(dict1['state_dictFC'])\n",
    "else:\n",
    "    dict1 = {'loss': [], 'valLoss': [], 'valPSNR': [], 'epoch': -1}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "RbQnS_KNilbR"
   },
   "source": [
    "### Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "both",
    "colab": {},
    "colab_type": "code",
    "id": "QrAS6TmP11RW"
   },
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "start = time.time()\n",
    "cLoss   = dict1['loss']\n",
    "valLoss = dict1['valLoss']\n",
    "valPSNR = dict1['valPSNR']\n",
    "checkpoint_counter = 0\n",
    "\n",
    "### Main training loop\n",
    "for epoch in range(dict1['epoch'] + 1, EPOCHS):\n",
    "    clear_output()\n",
    "    print(\"Epoch: \", epoch)\n",
    "    \n",
    "    # Plots\n",
    "    if (epoch):\n",
    "        Plot(1, cLoss, 'red')\n",
    "        Plot(1, valLoss, 'blue')\n",
    "        Plot(2, valPSNR, 'green')\n",
    "        display(plt.gcf())\n",
    "    \n",
    "    # Append and reset\n",
    "    cLoss.append([])\n",
    "    valLoss.append([])\n",
    "    valPSNR.append([])\n",
    "    iLoss = 0\n",
    "    \n",
    "    # Increment scheduler count    \n",
    "    scheduler.step()\n",
    "    \n",
    "    for trainIndex, (trainData, trainFrameIndex) in enumerate(trainloader, 0):\n",
    "        \n",
    "\t\t## Getting the input and the target from the training set\n",
    "        frame0, frameT, frame1 = trainData\n",
    "        \n",
    "        I0 = frame0.to(device)\n",
    "        I1 = frame1.to(device)\n",
    "        IFrame = frameT.to(device)\n",
    "        \n",
    "        optimizer.zero_grad()\n",
    "        \n",
    "        # Calculate flow between reference frames I0 and I1\n",
    "        flowOut = flowComp(torch.cat((I0, I1), dim=1))\n",
    "        \n",
    "        # Extracting flows between I0 and I1 - F_0_1 and F_1_0\n",
    "        F_0_1 = flowOut[:,:2,:,:]\n",
    "        F_1_0 = flowOut[:,2:,:,:]\n",
    "        \n",
    "        fCoeff = model.getFlowCoeff(trainFrameIndex, device)\n",
    "        \n",
    "        # Calculate intermediate flows\n",
    "        F_t_0 = fCoeff[0] * F_0_1 + fCoeff[1] * F_1_0\n",
    "        F_t_1 = fCoeff[2] * F_0_1 + fCoeff[3] * F_1_0\n",
    "        \n",
    "        # Get intermediate frames from the intermediate flows\n",
    "        g_I0_F_t_0 = trainFlowBackWarp(I0, F_t_0)\n",
    "        g_I1_F_t_1 = trainFlowBackWarp(I1, F_t_1)\n",
    "        \n",
    "        # Calculate optical flow residuals and visibility maps\n",
    "        intrpOut = ArbTimeFlowIntrp(torch.cat((I0, I1, F_0_1, F_1_0, F_t_1, F_t_0, g_I1_F_t_1, g_I0_F_t_0), dim=1))\n",
    "        \n",
    "        # Extract optical flow residuals and visibility maps\n",
    "        F_t_0_f = intrpOut[:, :2, :, :] + F_t_0\n",
    "        F_t_1_f = intrpOut[:, 2:4, :, :] + F_t_1\n",
    "        V_t_0   = F.sigmoid(intrpOut[:, 4:5, :, :])\n",
    "        V_t_1   = 1 - V_t_0\n",
    "        \n",
    "        # Get intermediate frames from the intermediate flows\n",
    "        g_I0_F_t_0_f = trainFlowBackWarp(I0, F_t_0_f)\n",
    "        g_I1_F_t_1_f = trainFlowBackWarp(I1, F_t_1_f)\n",
    "        \n",
    "        wCoeff = model.getWarpCoeff(trainFrameIndex, device)\n",
    "        \n",
    "        # Calculate final intermediate frame \n",
    "        Ft_p = (wCoeff[0] * V_t_0 * g_I0_F_t_0_f + wCoeff[1] * V_t_1 * g_I1_F_t_1_f) / (wCoeff[0] * V_t_0 + wCoeff[1] * V_t_1)\n",
    "        \n",
    "        # Loss\n",
    "        recnLoss = L1_lossFn(Ft_p, IFrame)\n",
    "            \n",
    "        prcpLoss = MSE_LossFn(vgg16_conv_4_3(Ft_p), vgg16_conv_4_3(IFrame))\n",
    "        \n",
    "        warpLoss = L1_lossFn(g_I0_F_t_0, IFrame) + L1_lossFn(g_I1_F_t_1, IFrame) + L1_lossFn(trainFlowBackWarp(I0, F_1_0), I1) + L1_lossFn(trainFlowBackWarp(I1, F_0_1), I0)\n",
    "        \n",
    "        loss_smooth_1_0 = torch.mean(torch.abs(F_1_0[:, :, :, :-1] - F_1_0[:, :, :, 1:])) + torch.mean(torch.abs(F_1_0[:, :, :-1, :] - F_1_0[:, :, 1:, :]))\n",
    "        loss_smooth_0_1 = torch.mean(torch.abs(F_0_1[:, :, :, :-1] - F_0_1[:, :, :, 1:])) + torch.mean(torch.abs(F_0_1[:, :, :-1, :] - F_0_1[:, :, 1:, :]))\n",
    "        loss_smooth = loss_smooth_1_0 + loss_smooth_0_1\n",
    "          \n",
    "        # Total Loss - Coefficients 204 and 102 are used instead of 0.8 and 0.4\n",
    "        # since the loss in paper is calculated for input pixels in range 0-255\n",
    "        # and the input to our network is in range 0-1\n",
    "        loss = 204 * recnLoss + 102 * warpLoss + 0.005 * prcpLoss + loss_smooth\n",
    "        \n",
    "        # Backpropagate\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        iLoss += loss.item()\n",
    "               \n",
    "        # Validation and progress every `PROGRESS_ITER` iterations\n",
    "        if ((trainIndex % PROGRESS_ITER) == PROGRESS_ITER - 1):\n",
    "            end = time.time()\n",
    "            \n",
    "            psnr, vLoss, valImg = validate()\n",
    "            \n",
    "            valPSNR[epoch].append(psnr)\n",
    "            valLoss[epoch].append(vLoss)\n",
    "            \n",
    "            #Tensorboard\n",
    "            itr = trainIndex + epoch * (len(trainloader))\n",
    "            \n",
    "            writer.add_scalars('Loss', {'trainLoss': iLoss/PROGRESS_ITER,\n",
    "                                        'validationLoss': vLoss}, itr)\n",
    "            writer.add_scalar('PSNR', psnr, itr)\n",
    "            \n",
    "            writer.add_image('Validation',valImg , itr)\n",
    "            #####\n",
    "            \n",
    "            endVal = time.time()\n",
    "            \n",
    "            print(\" Loss: %0.6f  Iterations: %4d/%4d  TrainExecTime: %0.1f  ValLoss:%0.6f  ValPSNR: %0.4f  ValEvalTime: %0.2f LearningRate: %f\" % (iLoss / PROGRESS_ITER, trainIndex, len(trainloader), end - start, vLoss, psnr, endVal - end, get_lr(optimizer)))\n",
    "            \n",
    "            \n",
    "            cLoss[epoch].append(iLoss/PROGRESS_ITER)\n",
    "            iLoss = 0\n",
    "            start = time.time()\n",
    "    \n",
    "    # Create checkpoint after every `CHECKPOINT_EPOCH` epochs\n",
    "    if ((epoch % CHECKPOINT_EPOCH) == CHECKPOINT_EPOCH - 1):\n",
    "        dict1 = {\n",
    "                'Detail':\"End to end Super SloMo.\",\n",
    "                'epoch':epoch,\n",
    "                'timestamp':datetime.datetime.now(),\n",
    "                'trainBatchSz':TRAIN_BATCH_SIZE,\n",
    "                'validationBatchSz':VALIDATION_BATCH_SIZE,\n",
    "                'learningRate':get_lr(optimizer),\n",
    "                'loss':cLoss,\n",
    "                'valLoss':valLoss,\n",
    "                'valPSNR':valPSNR,\n",
    "                'state_dictFC': flowComp.state_dict(),\n",
    "                'state_dictAT': ArbTimeFlowIntrp.state_dict(),\n",
    "                }\n",
    "        torch.save(dict1, CHECKPOINT_DIR + \"/SuperSloMo\" + str(checkpoint_counter) + \".ckpt\")\n",
    "        checkpoint_counter += 1\n",
    "    plt.close('all')"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "train.ipynb",
   "provenance": [],
   "version": "0.3.2"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}